import einops
def distance(query, memory):
    """Squared Euclidean distance between a batch of queries and a static memory bank.

    Each position ``j`` along the ``n`` axis has its own set of ``B`` memory
    vectors, and that memory is shared across the batch axis of ``query``.
    For every batch element ``i``, position ``j`` and memory slot ``k``::

        out[i, j, k] = sum over c of (query[i, j, c] - memory[j, k, c]) ** 2

    No square root is taken. Despite the function name, the result is the
    *squared* L2 distance.

    :param query:  shape ``(b, n, c)``. Presumably a ``torch.Tensor``
                   (the result is combined via ``.pow``); confirm with callers.
    :param memory: shape ``(n, B, c)``.
    :returns:      shape ``(b, n, B)``.
    """
    # Add singleton axes so the subtraction broadcasts to (b, n, B, c):
    # query gains a memory-slot axis, memory gains a batch axis.
    # Note: the full (b, n, B, c) difference is materialized before reduction.
    expanded_query, expanded_memory = (
        einops.rearrange(query, "b n c -> b n 1 c"),
        einops.rearrange(memory, "n B c -> 1 n B c"),
    )
    squared_diff = (expanded_query - expanded_memory).pow(2)
    # Axis 3 of the 4-D broadcast result is the channel axis ``c``.
    return squared_diff.sum(3)

